# import library for 이미지 가공
from PIL import Image
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
import cv2
# Data Load // Data 디렉토리 아래 파일이 있어야 함
digit_data_trval=np.load('Data/digit_data_TrVal-1.npz')
digit_data_test=np.load('Data/digit_data_10_Test_1.npz')
op_data_trval=np.load('Data/op_data_TrVal-1.npz')
op_data_test=np.load('Data/op_data_10_Test_1.npz')
#print(digit_data_trval['img'].shape, digit_data_trval['img'].dtype,digit_data_trval['label'].shape, digit_data_trval['label'].dtype)
#print(digit_data_test['img'].shape, digit_data_test['img'].dtype, digit_data_test['label'].shape, digit_data_test['label'].dtype)
#print(op_data_trval['img'].shape, op_data_trval['img'].dtype, op_data_trval['label'].shape, op_data_trval['label'].dtype)
#print(op_data_test['img'].shape, op_data_test['img'].dtype, op_data_test['label'].shape, op_data_test['label'].dtype)
# train set 이미지와 레이블 추출
op_trval_images = op_data_trval['img']
op_trval_labels = op_data_trval['label']
digit_trval_images = digit_data_trval['img']
digit_trval_labels = digit_data_trval['label']
# train set 이미지와 레이블 합치기
merged_trval_images = list(digit_trval_images) + list(op_trval_images)
merged_trval_labels = list(digit_trval_labels) + list(op_trval_labels)
# test set 이미지와 레이블 추출
op_test_images = op_data_test['img']
op_test_labels = op_data_test['label']
digit_test_images = digit_data_test['img']
digit_test_labels = digit_data_test['label']
# test 이미지와 레이블 합치기
merged_test_images = list(op_test_images) + list(digit_test_images)
merged_test_labels = list(op_test_labels) + list(digit_test_labels)
merged_trval_images = np.array(merged_trval_images)
merged_test_images = np.array(merged_test_images)
# 이미지를 1차원으로 변환
merged_trval_images_flat = merged_trval_images.reshape(merged_trval_images.shape[0], -1)
merged_test_images_flat = merged_test_images.reshape(merged_test_images.shape[0], -1)
#valid_labels = set(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])#, '+', '-', '/', 'x'])
valid_labels = set(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '/', 'x'])
# 숫자 및 연산 기호가 아닌 데이터의 인덱스 찾기
invalid_trval_indices = [i for i, label in enumerate(merged_trval_labels) if label not in valid_labels]
invalid_test_indices = [i for i, label in enumerate(merged_test_labels) if label not in valid_labels]
# 숫자 및 연산 기호가 아닌 데이터 제거
d_X_train = np.delete(merged_trval_images_flat, invalid_trval_indices, axis=0)
d_y_train = np.delete(merged_trval_labels, invalid_trval_indices)
d_X_test = np.delete(merged_test_images_flat, invalid_test_indices, axis=0)
d_y_test = np.delete(merged_test_labels, invalid_test_indices)
# 데이터는 셔플 되어 있다고 가정
from collections import defaultdict
# 각 레이블별 인덱스 추출
label_indices_train = defaultdict(list)
label_indices_test = defaultdict(list)
# Train set에서 각 레이블별 인덱스 추출
for idx, label in enumerate(d_y_train):
label_indices_train[label].append(idx)
# Test set에서 각 레이블별 인덱스 추출
for idx, label in enumerate(d_y_test):
label_indices_test[label].append(idx)
# 데이터 크기 반으로
size=10
# 각 레이블별로 train set은 400개, test set은 70개로 제한
selected_indices_train = []
selected_indices_test = []
for label, indices in label_indices_train.items():
selected_indices_train.extend(indices[:int(420*size)])
for label, indices in label_indices_test.items():
selected_indices_test.extend(indices[:int(70*size)])
# 선택된 인덱스로부터 데이터 구성
X_train = d_X_train[selected_indices_train]
y_train = d_y_train[selected_indices_train]
X_test = d_X_test[selected_indices_test]
y_test = d_y_test[selected_indices_test]
# 모델 선정용 데이터 shape 확인
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:", X_test.shape)
print("y_test shape:", y_test.shape)
# 모델 선정용 데이터 절감 비율 확인
train_reduction_ratio = round(d_y_train.shape[0] / y_train.shape[0], 2)
test_reduction_ratio = round(d_y_test.shape[0] / y_test.shape[0], 2)
print("\ntrain 절감 비율 : ", train_reduction_ratio)
print("test 절감 비율 : ",test_reduction_ratio)